import cv2 as cv2
import numpy as np
from matplotlib import pyplot as plt
from scipy import optimize
from scipy.spatial import distance as dist


def track2D(pathname, output):
    """Detect and draw the best-fitting circle on every frame of a video.

    Each frame is converted to grayscale and passed through
    circleExtraction; the annotated frame is written to *output*.
    Nothing is displayed on screen.

    INPUTS
    * pathname: input video (anything cv2.VideoCapture accepts)
    * output: output video path, encoded with the XVID fourcc (e.g. .avi)
    OUTPUTS
    * None; the annotated video is written to *output*.
    """
    cap = cv2.VideoCapture(pathname)
    # Stop here rather than building a writer with a bogus size/fps when the
    # input cannot be read.
    if not cap.isOpened():
        print("Error opening video stream or file")
        cap.release()
        return

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(output, fourcc, fps, (width, height), True)
    print('reading frames', end='')
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break  # end of stream (or read error)
            print(".", end='')
            img = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

            result = circleExtraction(img)
            # circleExtraction signals failure by returning -1 instead of its
            # output tuple; write the unannotated frame so the output video
            # keeps the same number of frames.
            if isinstance(result, tuple):
                out.write(result[-1])  # img_circle_RGB
            else:
                out.write(frame)
    finally:
        # Release both handles even if processing a frame raised.
        cap.release()
        out.release()
        print()


def circleExtraction(img, blurSize = 7, thresholdType = 'OTSU', thresholdValue = 127, openingIter = 3, skeletonType = 'GUOHALL'):
    """
    # Curvature radius extraction

    ## Principle

    *Image grayscale -> Thresholding segmentation -> Opening -> Thinning (skeleton extraction) -> circle fit*

    ## To do / limitations

    * Computational time
    * Scale calibration
    * Multiple circles/curvatures or non-bent parts (e.g. a support)

    ## Sources

    * [OpenCV documentation](https://docs.opencv.org/3.1.0/index.html)
    * https://scipy-cookbook.readthedocs.io/items/Least_Squares_Circle.html

    INPUTS
    * img: grayscale image
    * blurSize: median blur kernel size (default 7)
    * thresholdType: type of thresholding
        OTSU (default): Otsu's automatic threshold
        VALUE: fixed threshold thresholdValue (default 127)
    * openingIter: number of opening iterations (default 3)
    * skeletonType: skeleton extraction algorithm, ZHANGSUEN or GUOHALL (default GUOHALL)

    OUTPUTS
    * -1 if thresholdType/skeletonType is not recognised, or if the skeleton
      has fewer than 3 usable pixels (not enough to fit a circle)
    * otherwise the tuple
      (R, img, blur, thresh, opening, thinning, circle_RGB, img_circle_RGB)
        R: fitted radius (float, pixels)
        img, blur, thresh, opening, thinning: grayscale images
        circle_RGB, img_circle_RGB: 3-channel images with the fitted circle drawn
    """
    # Validate the string options first so a typo fails fast, before any
    # image processing is done.
    if thresholdType not in ('OTSU', 'VALUE'):
        print("Wrong threshold type")
        return(-1)
    if skeletonType not in ('GUOHALL', 'ZHANGSUEN'):
        print("Wrong skeleton type")
        return(-1)

    ### Threshold segmentation
    blur = cv2.medianBlur(img, blurSize)
    if thresholdType == 'OTSU':
        # Otsu's thresholding: the 0 threshold is ignored, Otsu picks it
        _, thresh = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    else:
        # global thresholding
        _, thresh = cv2.threshold(blur, thresholdValue, 255, cv2.THRESH_BINARY)

    ### Mathematical morphology - opening (removes small white specks)
    kernel = np.ones((5, 5), np.uint8)
    opened = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=openingIter)

    ### Skeleton extraction
    if skeletonType == 'GUOHALL':
        thinning = cv2.ximgproc.thinning(opened, cv2.ximgproc.THINNING_GUOHALL)
    else:
        thinning = cv2.ximgproc.thinning(opened, cv2.ximgproc.THINNING_ZHANGSUEN)

    ### Circular regression: least-squares method
    circle = thinning.copy()
    idx = cv2.findNonZero(circle)  # (N, 1, 2) array of (x, y), or None if empty
    pts = np.empty((0, 2)) if idx is None else idx.reshape(-1, 2)
    # Drop pixels within 5 px of the left/top border (presumably border
    # artifacts of the thinning); the right/bottom border is not filtered.
    keep = (pts[:, 0] >= 5) & (pts[:, 1] >= 5)
    x = pts[keep, 0].astype(float)
    y = pts[keep, 1].astype(float)
    if x.size < 3:
        print("Not enough skeleton pixels to fit a circle")
        return(-1)

    # Start the optimisation from the barycenter of the skeleton pixels.
    center_estimate = np.mean(x), np.mean(y)
    center, _ = optimize.leastsq(f_2, center_estimate, args=(x, y))
    xc, yc = center
    R = calc_R(center, x, y).mean()

    xc_int = int(xc)
    yc_int = int(yc)
    R_int = int(R)

    circle_RGB = cv2.cvtColor(circle, cv2.COLOR_GRAY2RGB)
    circle_RGB = cv2.circle(circle_RGB, (xc_int, yc_int), R_int, (0, 0, 255), 1)
    img_circle_RGB = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    img_circle_RGB = cv2.circle(img_circle_RGB, (xc_int, yc_int), R_int, (0, 0, 255), 1)

    return(R, img, blur, thresh, opened, thinning, circle_RGB, img_circle_RGB)



def threshSeg(img, blurSize = 7, thresholdType = 'VALUE', thresholdValue = 127):
    """
    Median-blur then binarise a grayscale image.

    INPUTS
    * img: grayscale image
    * blurSize: median blur kernel size (default 7)
    * thresholdType: type of thresholding
        OTSU: Otsu's automatic threshold
        VALUE (default): fixed threshold thresholdValue (default 127)
    OUTPUTS
    * -1 if thresholdType is not recognised
    * thresh: binary grayscale image (0 / 255)
    """
    # Reject an unknown type before doing any (wasted) filtering.
    if thresholdType not in ('OTSU', 'VALUE'):
        print("Wrong threshold type")
        return(-1)

    blur = cv2.medianBlur(img, blurSize)
    if thresholdType == 'OTSU':
        # Otsu's thresholding: the 0 threshold is ignored, Otsu picks it
        _, thresh = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    else:
        # global thresholding
        _, thresh = cv2.threshold(blur, thresholdValue, 255, cv2.THRESH_BINARY)
    return thresh


def opening(thresh, openingIter = 3):
    """
    Morphological opening with a 5x5 square structuring element.

    INPUTS
    * thresh: grayscale (binary) image
    * openingIter: number of opening iterations (default 3)
    OUTPUTS
    * opening: grayscale image after the opening
    """
    square_5x5 = np.ones((5, 5), dtype=np.uint8)
    return cv2.morphologyEx(thresh, cv2.MORPH_OPEN, square_5x5,
                            iterations=openingIter)


def skeleton(opening, skeletonType = 'GUOHALL'):
    """
    Thin a binary image down to a one-pixel-wide skeleton.

    INPUTS
    * opening: grayscale (binary) image
    * skeletonType: thinning algorithm, 'GUOHALL' (default) or 'ZHANGSUEN'
    OUTPUTS
    * -1 if skeletonType is not recognised
    * skeleton: grayscale image
    """
    supported = ('GUOHALL', 'ZHANGSUEN')
    if skeletonType not in supported:
        print("Wrong skeleton type")
        return(-1)
    # cv2.ximgproc names the algorithms THINNING_GUOHALL / THINNING_ZHANGSUEN.
    algorithm = getattr(cv2.ximgproc, 'THINNING_' + skeletonType)
    return cv2.ximgproc.thinning(opening, algorithm)


def coordExtraction(skeleton, edges = 5):
    """
    Extract the coordinates of the non-zero pixels, ignoring a border.

    INPUTS
    * skeleton: single-channel (grayscale) image
    * edges (default = 5): number of pixels on each side of the image not to
      consider; a pixel is kept when edges <= x < width - edges and
      edges <= y < height - edges
    OUTPUTS
    * x, y: float arrays of x (column) and y (row) coordinates of the kept
      non-zero pixels, in row-major order; empty arrays if there are none
    """
    skeleton = np.asarray(skeleton)
    height, width = skeleton.shape[:2]
    # np.nonzero returns (rows, cols) in row-major order, the same order
    # cv2.findNonZero would give, but yields empty arrays instead of None.
    rows, cols = np.nonzero(skeleton)
    # Strict upper bounds so exactly `edges` pixels are dropped on every side.
    keep = ((cols >= edges) & (cols < width - edges)
            & (rows >= edges) & (rows < height - edges))
    x = cols[keep].astype(float)
    y = rows[keep].astype(float)
    return(x, y)
    
def circRegression(x,y):
    """
    Circular regression: least-squares method.

    The center minimises the spread of the point-to-center distances
    (scipy.optimize.leastsq on f_2, started from the barycenter). The
    radius is the mean of those distances.

    INPUTS
    * x, y: arrays of x and y coordinates of the points
    OUTPUTS
    * R: radius of the circle (float)
    * xc, yc: center of the circle (float)
    * residu: sum of squared differences between each point's distance to
      the center and R
    """
    # The barycenter of the points is the initial guess for the center.
    center_estimate = np.mean(x), np.mean(y)
    # leastsq also returns an integer status flag; it is not checked here.
    center, _ = optimize.leastsq(f_2, center_estimate, args=(x, y))
    xc, yc = center
    Ri = calc_R(center, x, y)
    R = Ri.mean()
    residu = np.sum((Ri - R) ** 2)
    return(R, xc, yc, residu)


def calc_R(c, x, y):
    """Return the Euclidean distance of each 2D point (x[i], y[i]) to the center c = (xc, yc)."""
    dx = x - c[0]
    dy = y - c[1]
    return np.sqrt(dx**2 + dy**2)

def f_2(c, x,y):
    """Residuals for the circle fit centered at c = (xc, yc).

    Returns, for each point, its distance to c minus the mean of those
    distances (i.e. the deviation from the mean circle around c).
    """
    radii = np.sqrt((x - c[0])**2 + (y - c[1])**2)
    return radii - radii.mean()
    
def drawCircle(img, xc, yc, R, rgb = (0,0,255), th =1):
    """
    Draw a circle outline on an image.

    INPUTS
    * img: RGB image (OpenCV drawing functions draw on the array they receive)
    * xc, yc: center of the circle (float, truncated to int pixels)
    * R: radius of the circle (float, truncated to int)
    * rgb (default (0,0,255)), th (default 1): color and line thickness
    OUTPUTS
    * circle: the image returned by cv2.circle, with the circle drawn
    """
    center = (int(xc), int(yc))
    return cv2.circle(img, center, int(R), rgb, th)

def drawArray(img, x, y, rgb = (0,0,255)):
    """
    Color a set of pixels on a copy of an image.

    INPUTS
    * img: RGB image of shape (H, W, 3); it is not modified
    * x, y: arrays (or sequences) of column / row coordinates; floats are
      truncated toward zero
    * rgb (default (0,0,255)): color written to every listed pixel
    OUTPUTS
    * img_array: copy of img with the listed pixels colored
    """
    img_array = img.copy()
    r, g, b = rgb
    # One fancy-index assignment replaces the per-pixel Python loop;
    # astype(int) truncates toward zero exactly like int().
    cols = np.asarray(x).astype(int)
    rows = np.asarray(y).astype(int)
    img_array[rows, cols] = [r, g, b]
    return(img_array)

def scale(x, scaleFactor=4):
    """
    Convert a length (or array of lengths) from pixels to physical units.

    INPUT
    * x: value or numpy array, in pixels
    * scaleFactor: area of one pixel (mm^2/pixel, default 4)
    OUTPUT
    * y: scaled value or array of values (mm)
    """
    # The side of a square pixel is the square root of its area.
    pixel_width = scaleFactor ** 0.5
    return pixel_width * x

def getScale(img, size_object = 100):
    """
    Get the scale from an image of an object of known area.

    The image is thresholded (threshSeg defaults) and opened, the largest
    contour is taken as the calibration object, and its minimum-area
    bounding rectangle gives its area in pixels.

    INPUTS
    * img: grayscale image
    * size_object: known area of the rectangular object (mm^2, default 100)
    OUTPUT
    * scaleFactor: mm^2/pixel
    * img_scale: RGB image with the fitted rectangle drawn
    RAISES
    * ValueError if no object is found or its measured area is zero
    """
    thresh = threshSeg(img)
    img_scale = opening(thresh)

    # findContours returns (image, contours, hierarchy) in OpenCV 3 and
    # (contours, hierarchy) in OpenCV 4; contours is second-to-last in both.
    found = cv2.findContours(img_scale, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    contours = found[-2]
    if len(contours) == 0:
        raise ValueError("getScale: no object found in the calibration image")
    # The calibration object is assumed to be the largest blob; smaller
    # contours are treated as noise.
    cnt = max(contours, key=cv2.contourArea)

    rect = cv2.minAreaRect(cnt)
    # Corners of the rectangle as integer pixels (np.intp replaces the
    # np.int0 alias removed in NumPy 2).
    box = cv2.boxPoints(rect).astype(np.intp)
    img_scale = cv2.cvtColor(img_scale, cv2.COLOR_GRAY2RGB)
    img_scale = cv2.drawContours(img_scale, [box], 0, (0, 0, 255), 2)

    # boxPoints lists the corners consecutively, so box[0]-box[1] and
    # box[0]-box[3] are two adjacent sides.
    (tl, tr, br, bl) = box
    dA = dist.euclidean(tl, tr)
    dB = dist.euclidean(tl, bl)
    size_pixels = dA * dB
    if size_pixels == 0:
        raise ValueError("getScale: detected object has zero area")

    scaleFactor = size_object / size_pixels
    return(scaleFactor, img_scale)

